import java.util.HashMap;
import java.util.Map;

/**
 * Finds the longest contiguous subarray with a balanced count of 1s versus other values.
 */
class Solution {
    /**
     * Returns the length of the longest contiguous subarray of {@code nums} in which
     * the number of elements equal to 1 matches the number of all other elements.
     *
     * <p>Presumably {@code nums} contains only 0s and 1s (the usual problem statement);
     * any value other than 1 is treated like a 0 here.
     *
     * @param nums the input array; must not be {@code null}
     * @return the length of the longest balanced subarray, or 0 if there is none
     */
    public int findMaxLength(int[] nums) {
        final int n = nums.length;
        if (n < 2) {
            // A balanced subarray needs at least two elements.
            return 0;
        }

        // Earliest index at which each running balance was first reached.
        // Balance 0 is treated as reached "before" the array (index -1) so that
        // balanced prefixes starting at index 0 are measured correctly.
        Map<Integer, Integer> firstSeen = new HashMap<>();
        firstSeen.put(0, -1);

        int balance = 0;
        int best = 0;
        for (int idx = 0; idx < n; idx++) {
            balance += (nums[idx] == 1) ? 1 : -1;

            // Records idx only if this balance is new; otherwise returns the
            // earlier index. Equal balances at two indices mean the elements
            // strictly between them are balanced.
            Integer earliest = firstSeen.putIfAbsent(balance, idx);
            if (earliest != null) {
                best = Math.max(best, idx - earliest);
            }
        }
        return best;
    }
}